package com.sharelords.biz.creator.memory;

import com.sharelords.biz.constant.Constant;
import com.sharelords.biz.creator.compile.CompilerCreator;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.Log;

import javax.tools.*;
import java.io.File;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.net.URI;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Class loader that compiles Java source strings in memory and defines the resulting classes.
 *
 * <p>Compiled bytecode is kept in {@code classBytes} until the class is first requested through
 * {@link #findClass(String)}; at that point the bytes are removed from the map and defined.
 * The backing map is a plain {@link HashMap} and this class performs no synchronization itself.
 *
 * @author 千古龙少
 * @date 2020-06-14 20:35:26
 */
public class MemoryClassLoader extends URLClassLoader {

    /** Matches binary names ending in {@code $<number>}, i.e. anonymous inner classes such as {@code Outer$1}. */
    private static final Pattern ANONYMOUS_CLASS_NAME = Pattern.compile(".*[$][1-9][0-9]*$");

    /** Compiled but not yet defined classes: fully qualified class name -> bytecode. */
    private final Map<String, byte[]> classBytes = new HashMap<>();

    private static final MemoryClassLoader DEFAULT_LOADER = new MemoryClassLoader();

    // These two fields are populated by init(), which runs from the constructor call above.
    // They deliberately have no initializer: static initializers run in textual order, so an
    // initializer here would overwrite the values set while DEFAULT_LOADER was being built.
    private static SpringJavaFileManager javaFileManager;

    private static List<String> optionList;

    private MemoryClassLoader() {
        super(new URL[0], MemoryClassLoader.class.getClassLoader());
        init();
    }

    /**
     * Lazily creates the shared file manager and the compiler option list.
     */
    private static void init() {
        if (javaFileManager == null) {
            javaFileManager = getStandardFileManager(null, null, null);
        }
        if (optionList == null) {
            optionList = getOptionList();
        }
    }

    /**
     * Defines the class from in-memory bytecode when available, otherwise defers to
     * {@link URLClassLoader#findClass(String)}.
     *
     * <p>The bytecode entry is removed before the class is defined, so each registered class is
     * defined from memory at most once.
     *
     * @param name fully qualified class name
     * @return the defined class
     * @throws ClassNotFoundException if no bytecode is registered and the superclass cannot find it
     */
    @Override
    public Class<?> findClass(String name) throws ClassNotFoundException {
        byte[] bytecode = classBytes.remove(name);
        if (bytecode == null) {
            return super.findClass(name);
        }
        return defineClass(name, bytecode, 0, bytecode.length);
    }

    /**
     * Returns the default class loader instance.
     *
     * @return the shared loader
     */
    public static MemoryClassLoader getInstance() {
        return DEFAULT_LOADER;
    }

    /**
     * Compiles a Java source string and registers the resulting bytecode with this loader.
     *
     * @param className fully qualified name of the class declared in {@code javaSrc}
     * @param javaSrc   Java source text
     * @throws NullPointerException if compilation fails or no compiler is available
     */
    public void registerJava(String className, String javaSrc) {
        Map<String, byte[]> compiled = this.compile(className, javaSrc);
        this.classBytes.putAll(Objects.requireNonNull(compiled,
                "compilation failed or no compiler available for class: " + className));
    }

    /**
     * Defines every pending class and maps each directly implemented interface to its implementation.
     *
     * <p>All pending classes are defined, but classes whose name ends in {@code $<number>}
     * (anonymous inner classes) are left out of the result. When several classes implement the
     * same interface, the last one processed wins.
     *
     * @return map of interface fully qualified name to implementing class
     * @throws ClassNotFoundException if a pending class cannot be resolved
     */
    public Map<String, Class<?>> getPreparedBizImplClasses() throws ClassNotFoundException {
        Map<String, Class<?>> resultMap = new HashMap<>(16);

        // Work on a snapshot: findClass removes entries from classBytes, and defining one class
        // may load (and therefore remove) a registered superclass or interface as a side effect.
        // Iterating the live key set would silently skip those classes.
        List<String> pendingNames = new ArrayList<>(classBytes.keySet());
        for (String className : pendingNames) {
            Class<?> clazz = this.getClass(className);

            if (ANONYMOUS_CLASS_NAME.matcher(className).matches()) {
                // Anonymous inner class: defined above, but not reported as an implementation.
                continue;
            }
            for (Class<?> interfaceClazz : clazz.getInterfaces()) {
                resultMap.put(interfaceClazz.getName(), clazz);
            }
        }

        return resultMap;
    }

    /**
     * Compiles the given source with the compiler supplied by {@code CompilerCreator}.
     *
     * @param className fully qualified class name
     * @param javaSrc   Java source text
     * @return class name to bytecode map, or {@code null} if no compiler is available or compilation fails
     */
    private Map<String, byte[]> compile(String className, String javaSrc) {
        JavaCompiler compiler = CompilerCreator.getCompiler();
        if (compiler == null) {
            return null;
        }

        return getMemoryClassByteArrayMap(className, javaSrc, compiler, javaFileManager);
    }

    /**
     * Builds the custom Java file manager on top of a fresh javac {@link Context}.
     *
     * @param listener diagnostic listener to register in the context, may be {@code null}
     * @param locale   locale stored in the context, may be {@code null}
     * @param charset  charset for the log writer and the file manager; {@code null} uses a plain
     *                 {@link PrintWriter} over {@code System.err}
     * @return the new file manager
     */
    private static SpringJavaFileManager getStandardFileManager(DiagnosticListener<? super JavaFileObject> listener, Locale locale, Charset charset) {
        Context context = new Context();
        context.put(Locale.class, locale);
        if (listener != null) {
            context.put(DiagnosticListener.class, listener);
        }

        PrintWriter writer = charset == null ? new PrintWriter(System.err, true) : new PrintWriter(new OutputStreamWriter(System.err, charset), true);
        context.put(Log.outKey, writer);

        return new SpringJavaFileManager(context, true, charset);
    }

    /**
     * Compiles one source string and returns the bytecode collected by the in-memory file manager.
     *
     * @param className  fully qualified class name
     * @param javaSrc    Java source text
     * @param compiler   Java compiler to run
     * @param stdManager file manager wrapped by the in-memory manager
     * @return class name to bytecode map, or {@code null} if compilation fails or throws
     */
    private static Map<String, byte[]> getMemoryClassByteArrayMap(String className, String javaSrc, JavaCompiler compiler, StandardJavaFileManager stdManager) {
        // NOTE(review): if MemoryJavaFileManager.close() forwards to the wrapped manager, this
        // try-with-resources closes the shared static file manager after every compilation — confirm.
        try (MemoryJavaFileManager manager = new MemoryJavaFileManager(stdManager)) {
            JavaFileObject javaFileObject = makeStringSource(className, javaSrc);
            JavaCompiler.CompilationTask task = compiler.getTask(null, manager, null, optionList, null, Collections.singletonList(javaFileObject));
            Boolean call = task.call();
            if (call != null && call) {
                return manager.getClassBytes();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        return null;
    }

    /**
     * Wraps a source string in a {@link JavaFileObject} the compiler can read.
     *
     * @param className fully qualified class name; dots become path separators in the URI
     * @param code      Java source text
     * @return source file object backed by {@code code}
     */
    private static JavaFileObject makeStringSource(String className, final String code) {
        String classPath = className.replace('.', '/') + JavaFileObject.Kind.SOURCE.extension;
        return new SimpleJavaFileObject(URI.create(Constant.STRING_FLAG + classPath), JavaFileObject.Kind.SOURCE) {
            @Override
            public CharBuffer getCharContent(boolean ignoreEncodingErrors) {
                return CharBuffer.wrap(code);
            }
        };
    }

    /**
     * Builds the option list passed to the compiler: {@code -classpath} followed by the classpath.
     *
     * @return compiler options
     */
    private static List<String> getOptionList() {
        List<String> options = new ArrayList<>();
        options.add("-classpath");
        options.add(resolveClasspath());
        return options;
    }

    /**
     * Derives the compile classpath from the context class loader's URLs.
     *
     * <p>Falls back to the {@code java.class.path} system property when the context class loader
     * is {@code null} or not a {@link URLClassLoader}, instead of failing with a cast error.
     *
     * @return classpath string joined with {@link File#pathSeparator}
     */
    private static String resolveClasspath() {
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        if (!(contextLoader instanceof URLClassLoader)) {
            return System.getProperty("java.class.path", "");
        }

        StringBuilder sb = new StringBuilder();
        for (URL url : ((URLClassLoader) contextLoader).getURLs()) {
            sb.append(url.getFile()).append(File.pathSeparator);
        }
        return sb.toString();
    }

    /**
     * Resolves a registered class by name without delegating to the parent loader.
     *
     * <p>If the bytecode is still pending it is defined through {@link #findClass(String)}.
     * If it is no longer pending (for example because it was already defined while another class
     * was being loaded), the previously loaded class is returned instead of defining it again.
     *
     * @param name fully qualified class name
     * @return the class object
     * @throws ClassNotFoundException if the class cannot be found
     */
    private Class<?> getClass(String name) throws ClassNotFoundException {
        if (!classBytes.containsKey(name)) {
            Class<?> alreadyLoaded = findLoadedClass(name);
            if (alreadyLoaded != null) {
                return alreadyLoaded;
            }
        }
        return this.findClass(name);
    }

}
